{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%%shell\n",
        "# Installs the latest dev build of TVM from PyPI. If you wish to build\n",
        "# from source, see https://tvm.apache.org/docs/install/from_source.html\n",
        "pip install apache-tvm --pre"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 4. microTVM PyTorch Tutorial\n",
        "**Authors**:\n",
        "[Mehrdad Hessar](https://github.com/mehrdadh)\n",
        "\n",
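        "This tutorial shows host-driven AoT compilation of a PyTorch model with microTVM. It can be run on an x86 CPU using the C runtime (CRT).\n",
        "\n",
        "```{note}\n",
        "This tutorial runs only on an x86 CPU using CRT. It does not run on Zephyr, because the model does not fit on the Zephyr boards currently supported.\n",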
        "```"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Install microTVM Python dependencies\n",
        "\n",
        "TVM does not include a Python serial communication package, so we must install one before using microTVM. We also need TFLite to load models.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%%shell\n",
        "pip install pyserial==3.5 tflite==2.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pathlib\n",
        "import torch\n",
        "import torchvision\n",
        "from torchvision import transforms\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "\n",
        "import tvm\n",
        "from tvm import relay\n",
        "from tvm.contrib.download import download_testdata\n",
        "from tvm.relay.backend import Executor\n",
        "import tvm.micro.testing\n",
        "import tvm.contrib.utils  # explicit import for tempdir(), used when creating the project"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load a pretrained PyTorch model\n",
        "\n",
        "First, load a pretrained, quantized MobileNetV2 from torchvision. Then download an image of a cat and preprocess it to use as the model input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = torchvision.models.quantization.mobilenet_v2(weights=\"DEFAULT\", quantize=True)\n",
        "model = model.eval()\n",
        "\n",
        "input_shape = [1, 3, 224, 224]\n",
        "input_data = torch.randn(input_shape)\n",
        "scripted_model = torch.jit.trace(model, input_data).eval()\n",
        "\n",
        "img_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\n",
        "img_path = download_testdata(img_url, \"cat.png\", module=\"data\")\n",
        "img = Image.open(img_path).resize((224, 224))\n",
        "\n",
        "# Preprocess the image and convert to tensor\n",
        "my_preprocess = transforms.Compose(\n",
        "    [\n",
        "        transforms.Resize(256),\n",
        "        transforms.CenterCrop(224),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "    ]\n",
        ")\n",
        "img = my_preprocess(img)\n",
        "img = np.expand_dims(img, 0)\n",
        "\n",
        "input_name = \"input0\"\n",
        "shape_list = [(input_name, input_shape)]\n",
        "relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)"
      ]
    },
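    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the preprocessing step above more concrete, the following is a minimal NumPy sketch of what `transforms.ToTensor()` followed by `transforms.Normalize(...)` computes for a single image. The helper name `normalize_chw` and the random stand-in image are illustrative assumptions and are not part of the tutorial; the mean and standard deviation values are copied from the `Normalize` call above.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Per-channel ImageNet statistics, copied from the Normalize call above.\n",
        "MEAN = np.array([0.485, 0.456, 0.406], dtype=\"float32\")\n",
        "STD = np.array([0.229, 0.224, 0.225], dtype=\"float32\")\n",
        "\n",
        "\n",
        "def normalize_chw(img_hwc_uint8):\n",
        "    # Hypothetical helper: uint8 HWC image -> normalized float32 NCHW batch.\n",
        "    x = img_hwc_uint8.astype(\"float32\") / 255.0  # scale pixel values to [0, 1]\n",
        "    x = x.transpose(2, 0, 1)  # reorder HWC -> CHW\n",
        "    x = (x - MEAN[:, None, None]) / STD[:, None, None]  # normalize each channel\n",
        "    return x[None, ...]  # add the batch dimension\n",
        "\n",
        "\n",
        "# Random stand-in for a 224x224 RGB image (not the cat image).\n",
        "fake_img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)\n",
        "print(normalize_chw(fake_img).shape)  # (1, 3, 224, 224)\n",
        "```\n",
        "\n",
        "The resulting shape matches `input_shape`, which is the shape passed to `relay.frontend.from_pytorch` through `shape_list`."
      ]
    },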
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
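        "## Define Target, Runtime and Executor\n",
        "\n",
        "This tutorial uses the AOT host-driven executor. To compile the model for an emulated embedded environment on an x86 machine, we use the C runtime (CRT) and the `host` micro target. With this setup, TVM compiles the model for the C runtime, and the result can run on an x86 CPU machine through the same flow that a physical microcontroller would use. CRT uses the main() function in `src/runtime/crt/host/main.cc`. To use physical hardware, replace `board` with another physical micro target, e.g. `nrf5340dk_nrf5340_cpuapp` or `mps2_an521`, and change the platform type to Zephyr. More target examples can be found in [Training Vision Models for microTVM on Arduino](tutorial-micro-train-arduino) and [microTVM TFLite Tutorial](tutorial_micro_tflite)."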
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "target = tvm.micro.testing.get_target(platform=\"crt\", board=None)\n",
        "\n",
        "# Use the C runtime (crt) and enable static linking by setting system-lib to True\n",
        "runtime = tvm.relay.backend.Runtime(\"crt\", {\"system-lib\": True})\n",
        "\n",
        "# Use the AOT executor rather than graph or vm executors. Don't use unpacked API or C calling style.\n",
        "executor = Executor(\"aot\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compile the model\n",
        "\n",
        "Now, compile the model for the target:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "with tvm.transform.PassContext(\n",
        "    opt_level=3,\n",
        "    config={\"tir.disable_vectorize\": True},\n",
        "):\n",
        "    module = tvm.relay.build(\n",
        "        relay_mod, target=target, runtime=runtime, executor=executor, params=params\n",
        "    )"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create a microTVM project\n",
        "\n",
        "Now that we have the compiled model as an IRModule, we need to create a firmware project that uses the compiled model with microTVM. To do this, we use the Project API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects(\"crt\"))\n",
        "project_options = {\"verbose\": False, \"workspace_size_bytes\": 6 * 1024 * 1024}\n",
        "\n",
        "temp_dir = tvm.contrib.utils.tempdir() / \"project\"\n",
        "project = tvm.micro.generate_project(\n",
        "    str(template_project_path),\n",
        "    module,\n",
        "    temp_dir,\n",
        "    project_options,\n",
        ")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
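        "## Build, flash and execute the model\n",
        "\n",
        "Next, build the microTVM project and flash it. The flash step is specific to physical microcontrollers, so it is skipped when a microcontroller is emulated through the host `main.cc` or when a Zephyr emulated board is selected as the target."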
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "project.build()\n",
        "project.flash()\n",
        "\n",
        "input_data = {input_name: tvm.nd.array(img.astype(\"float32\"))}\n",
        "with tvm.micro.Session(project.transport()) as session:\n",
        "    aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())\n",
        "    aot_executor.set_input(**input_data)\n",
        "    aot_executor.run()\n",
        "    result = aot_executor.get_output(0).numpy()"
      ]
    },
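    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `result` array returned above holds one score per class. As a small illustration of how a top-1 lookup generalizes to top-k, here is a sketch that returns the indices of the `k` highest scores. `top_k` is a hypothetical helper and the scores are made-up stand-ins for `result`; neither is part of the tutorial.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def top_k(scores, k=5):\n",
        "    # Hypothetical helper: indices of the k largest scores, best first.\n",
        "    flat = np.asarray(scores).reshape(-1)\n",
        "    return np.argsort(flat)[::-1][:k]\n",
        "\n",
        "\n",
        "fake_scores = np.array([[0.1, 2.5, -1.0, 0.7]])  # stand-in for `result`\n",
        "print(top_k(fake_scores, k=2))  # [1 3]\n",
        "```\n",
        "\n",
        "With `k=1` and no tied scores, this selects the same index as `np.argmax`."
      ]
    },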
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Look up synset name\n",
        "\n",
        "Look up the predicted top-1 index in the 1000-class synset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "synset_url = (\n",
        "    \"https://raw.githubusercontent.com/Cadene/\"\n",
        "    \"pretrained-models.pytorch/master/data/\"\n",
        "    \"imagenet_synsets.txt\"\n",
        ")\n",
        "synset_name = \"imagenet_synsets.txt\"\n",
        "synset_path = download_testdata(synset_url, synset_name, module=\"data\")\n",
        "with open(synset_path) as f:\n",
        "    synsets = f.readlines()\n",
        "\n",
        "synsets = [x.strip() for x in synsets]\n",
        "splits = [line.split(\" \") for line in synsets]\n",
        "key_to_classname = {spl[0]: \" \".join(spl[1:]) for spl in splits}\n",
        "\n",
        "class_url = (\n",
        "    \"https://raw.githubusercontent.com/Cadene/\"\n",
        "    \"pretrained-models.pytorch/master/data/\"\n",
        "    \"imagenet_classes.txt\"\n",
        ")\n",
        "class_path = download_testdata(class_url, \"imagenet_classes.txt\", module=\"data\")\n",
        "with open(class_path) as f:\n",
        "    class_id_to_key = f.readlines()\n",
        "\n",
        "class_id_to_key = [x.strip() for x in class_id_to_key]\n",
        "\n",
        "# Get top-1 result for TVM\n",
        "top1_tvm = np.argmax(result)\n",
        "tvm_class_key = class_id_to_key[top1_tvm]\n",
        "\n",
        "# Convert input to PyTorch variable and get PyTorch result for comparison\n",
        "with torch.no_grad():\n",
        "    torch_img = torch.from_numpy(img)\n",
        "    output = model(torch_img)\n",
        "\n",
        "    # Get top-1 result for PyTorch\n",
        "    top1_torch = np.argmax(output.numpy())\n",
        "    torch_class_key = class_id_to_key[top1_torch]\n",
        "\n",
        "print(\"Relay top-1 id: {}, class name: {}\".format(top1_tvm, key_to_classname[tvm_class_key]))\n",
        "print(\"Torch top-1 id: {}, class name: {}\".format(top1_torch, key_to_classname[torch_class_key]))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
